{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "is_executing": true
    },
    "ExecuteTime": {
     "start_time": "2024-06-05T08:07:14.903735Z"
    }
   },
   "source": [
    "\"\"\"\n",
    " Sample script using EEGNet to classify Event-Related Potential (ERP) EEG data\n",
    " from a four-class classification task, using the sample dataset provided in\n",
    " the MNE [1, 2] package:\n",
    "     https://martinos.org/mne/stable/manual/sample_dataset.html#ch-sample-data\n",
    "   \n",
    " The four classes used from this dataset are:\n",
    "     LA: Left-ear auditory stimulation\n",
    "     RA: Right-ear auditory stimulation\n",
    "     LV: Left visual field stimulation\n",
    "     RV: Right visual field stimulation\n",
    "\n",
    " The code to process, filter and epoch the data are originally from Alexandre\n",
    " Barachant's PyRiemann [3] package, released under the BSD 3-clause. A copy of \n",
    " the BSD 3-clause license has been provided together with this software to \n",
    " comply with software licensing requirements. \n",
    " \n",
    " When you first run this script, MNE will download the dataset and prompt you\n",
    " to confirm the download location (defaults to ~/mne_data). Follow the prompts\n",
    " to continue. The dataset size is approx. 1.5GB download. \n",
    " \n",
    " For comparative purposes you can also compare EEGNet performance to using \n",
    " Riemannian geometric approaches with xDAWN spatial filtering [4-8] using \n",
    " PyRiemann (code provided below).\n",
    "\n",
    " [1] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck,\n",
    "     L. Parkkonen, M. Hämäläinen, MNE software for processing MEG and EEG data, \n",
    "     NeuroImage, Volume 86, 1 February 2014, Pages 446-460, ISSN 1053-8119.\n",
    "\n",
    " [2] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck, \n",
    "     R. Goj, M. Jas, T. Brooks, L. Parkkonen, M. Hämäläinen, MEG and EEG data \n",
    "     analysis with MNE-Python, Frontiers in Neuroscience, Volume 7, 2013.\n",
    "\n",
    " [3] https://github.com/alexandrebarachant/pyRiemann. \n",
    "\n",
    " [4] A. Barachant, M. Congedo ,\"A Plug&Play P300 BCI Using Information Geometry\"\n",
    "     arXiv:1409.0107. link\n",
    "\n",
    " [5] M. Congedo, A. Barachant, A. Andreev ,\"A New generation of Brain-Computer \n",
    "     Interface Based on Riemannian Geometry\", arXiv: 1310.8115.\n",
    "\n",
    " [6] A. Barachant and S. Bonnet, \"Channel selection procedure using riemannian \n",
    "     distance for BCI applications,\" in 2011 5th International IEEE/EMBS \n",
    "     Conference on Neural Engineering (NER), 2011, 348-351.\n",
    "\n",
    " [7] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, “Multiclass \n",
    "     Brain-Computer Interface Classification by Riemannian Geometry,” in IEEE \n",
    "     Transactions on Biomedical Engineering, vol. 59, no. 4, p. 920-928, 2012.\n",
    "\n",
    " [8] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, “Classification of \n",
    "     covariance matrices using a Riemannian-based kernel for BCI applications“, \n",
    "     in NeuroComputing, vol. 112, p. 172-178, 2013.\n",
    "\n",
    "\n",
    " Portions of this project are works of the United States Government and are not\n",
    " subject to domestic copyright protection under 17 USC Sec. 105.  Those \n",
    " portions are released world-wide under the terms of the Creative Commons Zero \n",
    " 1.0 (CC0) license.  \n",
    " \n",
    " Other portions of this project are subject to domestic copyright protection \n",
    " under 17 USC Sec. 105.  Those portions are licensed under the Apache 2.0 \n",
    " license.  The complete text of the license governing this material is in \n",
    " the file labeled LICENSE.TXT that is a part of this project's official \n",
    " distribution. \n",
    "\"\"\"\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# mne imports\n",
    "import mne\n",
    "from mne import io\n",
    "from mne.datasets import sample\n",
    "\n",
    "# EEGNet-specific imports\n",
    "from EEGModels import EEGNet\n",
    "from tensorflow.keras import utils as np_utils\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint\n",
    "from tensorflow.keras import backend as K\n",
    "\n",
    "# PyRiemann imports\n",
    "from pyriemann.estimation import XdawnCovariances\n",
    "from pyriemann.tangentspace import TangentSpace\n",
    "from pyriemann.utils.viz import plot_confusion_matrix\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# tools for plotting confusion matrices\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# while the default tensorflow ordering is 'channels_last' we set it here\n",
    "# to be explicit in case if the user has changed the default ordering\n",
    "K.set_image_data_format('channels_last')\n",
    "\n",
    "##################### Process, filter and epoch the data ######################\n",
    "data_path = sample.data_path()\n",
    "\n",
    "# Set parameters and read data\n",
    "raw_fname = data_path  / 'MEG/sample/sample_audvis_filt-0-40_raw.fif'\n",
    "event_fname = data_path / 'MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'\n",
    "tmin, tmax = -0., 1\n",
    "event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)\n",
    "\n",
    "# Setup for reading the raw data\n",
    "raw = io.Raw(raw_fname, preload=True, verbose=False)\n",
    "raw.filter(2, None, method='iir')  # replace baselining with high-pass\n",
    "events = mne.read_events(event_fname)\n",
    "\n",
    "raw.info['bads'] = ['MEG 2443']  # set bad channels\n",
    "picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                       exclude='bads')\n",
    "\n",
    "# Read epochs\n",
    "epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,\n",
    "                    picks=picks, baseline=None, preload=True, verbose=False)\n",
    "labels = epochs.events[:, -1]\n",
    "\n",
    "# extract raw data. scale by 1000 due to scaling sensitivity in deep learning\n",
    "X = epochs.get_data() * 1000  # format is in (trials, channels, samples)\n",
    "y = labels\n",
    "\n",
    "kernels, chans, samples = 1, 60, 151\n",
    "\n",
    "# take 50/25/25 percent of the data to train/validate/test\n",
    "X_train = X[0:144, ]\n",
    "Y_train = y[0:144]\n",
    "X_validate = X[144:216, ]\n",
    "Y_validate = y[144:216]\n",
    "X_test = X[216:, ]\n",
    "Y_test = y[216:]\n",
    "\n",
    "############################# EEGNet portion ##################################\n",
    "\n",
    "# convert labels to one-hot encodings.\n",
    "Y_train = np_utils.to_categorical(Y_train - 1)\n",
    "Y_validate = np_utils.to_categorical(Y_validate - 1)\n",
    "Y_test = np_utils.to_categorical(Y_test - 1)\n",
    "\n",
    "# convert data to NHWC (trials, channels, samples, kernels) format. Data \n",
    "# contains 60 channels and 151 time-points. Set the number of kernels to 1.\n",
    "X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)\n",
    "X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)\n",
    "X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)\n",
    "\n",
    "print('X_train shape:', X_train.shape)\n",
    "print(X_train.shape[0], 'train samples')\n",
    "print(X_test.shape[0], 'test samples')\n",
    "\n",
    "# configure the EEGNet-8,2,16 model with kernel length of 32 samples (other \n",
    "# model configurations may do better, but this is a good starting point)\n",
    "model = EEGNet(nb_classes=4, Chans=chans, Samples=samples,\n",
    "               dropoutRate=0.5, kernLength=32, F1=8, D=2, F2=16,\n",
    "               dropoutType='Dropout')\n",
    "\n",
    "# compile the model and set the optimizers\n",
    "model.compile(loss='categorical_crossentropy', optimizer='adam',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "# count number of parameters in the model\n",
    "numParams = model.count_params()\n",
    "\n",
    "# set a valid path for your system to record model checkpoints\n",
    "checkpointer = ModelCheckpoint(filepath='/tmp/checkpoint.weights.h5', verbose=1,\n",
    "                               save_best_only=True,save_weights_only=True)\n",
    "\n",
    "###############################################################################\n",
    "# if the classification task was imbalanced (significantly more trials in one\n",
    "# class versus the others) you can assign a weight to each class during \n",
    "# optimization to balance it out. This data is approximately balanced so we \n",
    "# don't need to do this, but is shown here for illustration/completeness. \n",
    "###############################################################################\n",
    "\n",
    "# the syntax is {class_1:weight_1, class_2:weight_2,...}. Here just setting\n",
    "# the weights all to be 1\n",
    "class_weights = {0: 1, 1: 1, 2: 1, 3: 1}\n",
    "\n",
    "################################################################################\n",
    "# fit the model. Due to very small sample sizes this can get\n",
    "# pretty noisy run-to-run, but most runs should be comparable to xDAWN + \n",
    "# Riemannian geometry classification (below)\n",
    "################################################################################\n",
    "fittedModel = model.fit(X_train, Y_train, batch_size=16, epochs=300,\n",
    "                        verbose=2, validation_data=(X_validate, Y_validate),\n",
    "                        callbacks=[checkpointer], class_weight=class_weights)\n",
    "\n",
    "# load optimal weights\n",
    "model.load_weights('/tmp/checkpoint.weights.h5')\n",
    "\n",
    "###############################################################################\n",
    "# can alternatively used the weights provided in the repo. If so it should get\n",
    "# you 93% accuracy. Change the WEIGHTS_PATH variable to wherever it is on your\n",
    "# system.\n",
    "###############################################################################\n",
    "\n",
    "# WEIGHTS_PATH = /path/to/EEGNet-8-2-weights.h5 \n",
    "# model.load_weights(WEIGHTS_PATH)\n",
    "\n",
    "###############################################################################\n",
    "# make prediction on test set.\n",
    "###############################################################################\n",
    "\n",
    "probs = model.predict(X_test)\n",
    "preds = probs.argmax(axis=-1)\n",
    "acc = np.mean(preds == Y_test.argmax(axis=-1))\n",
    "print(\"Classification accuracy: %f \" % (acc))\n",
    "\n",
    "############################# PyRiemann Portion ##############################\n",
    "\n",
    "# code is taken from PyRiemann's ERP sample script, which is decoding in \n",
    "# the tangent space with a logistic regression\n",
    "\n",
    "n_components = 2  # pick some components\n",
    "\n",
    "# set up sklearn pipeline\n",
    "clf = make_pipeline(XdawnCovariances(n_components),\n",
    "                    TangentSpace(metric='riemann'),\n",
    "                    LogisticRegression())\n",
    "\n",
    "preds_rg = np.zeros(len(Y_test))\n",
    "\n",
    "# reshape back to (trials, channels, samples)\n",
    "X_train = X_train.reshape(X_train.shape[0], chans, samples)\n",
    "X_test = X_test.reshape(X_test.shape[0], chans, samples)\n",
    "\n",
    "# train a classifier with xDAWN spatial filtering + Riemannian Geometry (RG)\n",
    "# labels need to be back in single-column format\n",
    "clf.fit(X_train, Y_train.argmax(axis=-1))\n",
    "preds_rg = clf.predict(X_test)\n",
    "\n",
    "# Printing the results\n",
    "acc2 = np.mean(preds_rg == Y_test.argmax(axis=-1))\n",
    "print(\"Classification accuracy: %f \" % (acc2))\n",
    "\n",
    "# plot the confusion matrices for both classifiers\n",
    "names = ['audio left', 'audio right', 'vis left', 'vis right']\n",
    "plt.figure(0)\n",
    "plot_confusion_matrix(preds, Y_test.argmax(axis=-1), names, title='EEGNet-8,2')\n",
    "\n",
    "plt.figure(1)\n",
    "plot_confusion_matrix(preds_rg, Y_test.argmax(axis=-1), names, title='xDAWN + RG')\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Filtering raw data in 1 contiguous segment\n",
      "Setting up high-pass filter at 2 Hz\n",
      "\n",
      "IIR filter parameters\n",
      "---------------------\n",
      "Butterworth highpass zero-phase (two-pass forward and reverse) non-causal filter:\n",
      "- Filter order 8 (effective, after forward-backward)\n",
      "- Cutoff at 2.00 Hz: -6.02 dB\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\mimovr\\AppData\\Local\\Temp\\ipykernel_16404\\2399109752.py:118: FutureWarning: The current default of copy=False will change to copy=True in 1.7. Set the value of copy explicitly to avoid this warning\n",
      "  X = epochs.get_data() * 1000  # format is in (trials, channels, samples)\n",
      "E:\\xialugui\\PoetryCache\\virtualenvs\\eegnet--dnrzntS-py3.11\\Lib\\site-packages\\keras\\src\\layers\\convolutional\\base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n",
      "  super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X_train shape: (144, 60, 151, 1)\n",
      "144 train samples\n",
      "72 test samples\n",
      "Epoch 1/300\n"
     ]
    }
   ],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
